In this project, we will use transfer learning to classify images from a 20-class subset of the Caltech-101 dataset. We use transfer learning because the dataset is small (only a few dozen images per class), which is too little data to train an effective model from scratch.

First, we load the libraries we need.

library(readr)
library(ggplot2)
library(dplyr)
library(methods)
library(stringi)
library(keras)
library(glmnet)

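Next, we load the Caltech-101 images. The dataset directory is expected to contain one subdirectory per class, and each image's label is the name of the subdirectory it sits in.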

input_dir <- "dataset"

# Collect image files (searching subdirectories) and keep only JPEG/PNG files
image_paths <- dir(input_dir, recursive = TRUE)
ext <- stri_match(image_paths, regex = "\\.([A-Za-z]+$)")[,2]
image_paths <- image_paths[stri_trans_tolower(ext) %in% c("jpg", "png", "jpeg")]

# The subdirectory name of each file is its class label
class_vector <- dirname(image_paths)
class_names <- levels(factor(class_vector))

# Load every image, resized to 224x224 RGB, into an n x 224 x 224 x 3 array
n <- length(class_vector)
Z <- array(0, dim = c(n, 224, 224, 3))
y <- as.numeric(factor(class_vector)) - 1L   # 0-based integer labels
for (i in seq_len(n))
{
  pt <- file.path(input_dir, image_paths[i])
  image <- image_to_array(image_load(pt, target_size = c(224,224)))
  Z[i,,,] <- image
}

# Shuffle images and labels with the same random permutation
set.seed(1)
index <- sample(seq_len(nrow(Z)))
Z <- Z[index,,,]
y <- y[index]
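The labelling and shuffling logic can be checked on a tiny example. This is a sketch with made-up file paths (not files from the actual dataset) that reuses the same dirname() and factor() steps: class names come out in alphabetical order, labels are 0-based, and applying one permutation to every per-image vector keeps labels and data aligned.

```r
# Hypothetical relative paths, in the form dir(input_dir, recursive = TRUE)
# returns them (one subdirectory per class); the file names are made up.
paths_toy <- c("pizza/image_0001.jpg", "crab/image_0002.jpg",
               "crab/image_0001.jpg", "yin_yang/image_0010.jpg")

labels_toy <- dirname(paths_toy)                   # directory name = class label
names_toy  <- levels(factor(labels_toy))           # levels are sorted alphabetically
y_toy      <- as.numeric(factor(labels_toy)) - 1L  # 0-based integer labels

names_toy   # "crab" "pizza" "yin_yang"
y_toy       # 1 0 0 2

# Adding 1 turns a 0-based label back into an index into the class names
stopifnot(identical(names_toy[y_toy + 1L], labels_toy))

# Shuffling: one permutation applied to every per-image vector keeps them aligned
set.seed(1)
perm <- sample(seq_along(y_toy))
stopifnot(identical(names_toy[y_toy[perm] + 1L], labels_toy[perm]))
```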

The dataset contains 20 classes and 1,084 images in total, with between 34 and 88 images per class. Let’s look at one randomly chosen example from each class:

par(mar = c(0,0,0,0))
par(mfrow = c(4, 5))
set.seed(1)
for (i in 0:19) {
  plot(0,0,xlim=c(0,1),ylim=c(0,1),axes= FALSE,type = "n")
  j <- sample(which(y == i), 1)
  rasterImage(Z[j,,,]/255,0,0,1,1)
  text(0.5, 0.1, class_names[i+1], cex = 2, col = "red")
}
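The loop above draws one random image per class with sample(which(y == i), 1). That is safe here because every class has dozens of images, but it leans on a well-known base R quirk: given a single positive number k, sample(k, 1) draws from 1:k instead of returning k. A defensive variant, using a hypothetical helper sample_one():

```r
# Hypothetical helper: draw one element of x, even when length(x) == 1.
# sample(x, 1) would treat a single positive number x as the range 1:x.
sample_one <- function(x) {
  stopifnot(length(x) >= 1)
  x[sample.int(length(x), 1L)]
}

set.seed(1)
idx <- 17L
sample(idx, 1)     # draws from 1:17, so usually not 17
sample_one(idx)    # always 17
```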

Before we jump into transfer learning, let’s train a simple model from scratch and see how well it performs. Z is a four-dimensional array holding our images (image x height x width x channel). First, we randomly split Z into training and validation sets, roughly 60/40:

z_train_id <- sample(c("train", "valid"), nrow(Z), TRUE, prob = c(0.6, 0.4))

Z_train <- Z[z_train_id == "train",,,]                  # Note: Z_train is a 4-D array
y_train <- to_categorical(y[z_train_id == "train"])
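to_categorical() turns the 0-based integer labels into one-hot rows. A base R sketch of the same idea (one_hot() is a hypothetical helper, not part of keras) also shows a subtlety: when the number of classes is not given, it is inferred as max(y) + 1, so if the highest-numbered class were missing from a random split, the matrix would have fewer columns. Passing num_classes = 20 to to_categorical() avoids that.

```r
# Hypothetical base R equivalent of keras::to_categorical() for 0-based labels
one_hot <- function(y, num_classes = max(y) + 1L) {
  m <- matrix(0, nrow = length(y), ncol = num_classes)
  m[cbind(seq_along(y), y + 1L)] <- 1   # set column (label + 1) in each row
  m
}

one_hot(c(0, 2, 1, 2))                   # 4 x 3 matrix, one 1 per row

# The inferred width depends on the largest label present:
ncol(one_hot(c(0, 1)))                   # 2
ncol(one_hot(c(0, 1), num_classes = 3))  # 3
```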

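Define and compile a simple convolutional network: three blocks of 3x3 convolution, 2x2 max pooling and ReLU activation, followed by a dense softmax layer over the 20 classes: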

simple_model <- keras_model_sequential()
simple_model %>%
  layer_conv_2d(filters = 16, kernel_size = c(3, 3),
                  input_shape = dim(Z_train)[-1],
                  padding = "same") %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_activation(activation = "relu") %>%
  layer_conv_2d(filters = 16, kernel_size = c(3, 3),
                  padding = "same") %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_activation(activation = "relu") %>%
  layer_conv_2d(filters = 16, kernel_size = c(3, 3),
                  padding = "same") %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_activation(activation = "relu") %>%
  layer_flatten() %>%
  layer_dense(units = ncol(y_train)) %>%
  layer_activation(activation = "softmax")
simple_model %>% compile(loss = 'categorical_crossentropy',
                  optimizer = optimizer_sgd(lr = 0.01, momentum = 0.8),
                  metrics = c('accuracy'))
simple_model
## Model
## ___________________________________________________________________________
## Layer (type)                     Output Shape                  Param #     
## ===========================================================================
## conv2d_1 (Conv2D)                (None, 224, 224, 16)          448         
## ___________________________________________________________________________
## max_pooling2d_1 (MaxPooling2D)   (None, 112, 112, 16)          0           
## ___________________________________________________________________________
## activation_1 (Activation)        (None, 112, 112, 16)          0           
## ___________________________________________________________________________
## conv2d_2 (Conv2D)                (None, 112, 112, 16)          2320        
## ___________________________________________________________________________
## max_pooling2d_2 (MaxPooling2D)   (None, 56, 56, 16)            0           
## ___________________________________________________________________________
## activation_2 (Activation)        (None, 56, 56, 16)            0           
## ___________________________________________________________________________
## conv2d_3 (Conv2D)                (None, 56, 56, 16)            2320        
## ___________________________________________________________________________
## max_pooling2d_3 (MaxPooling2D)   (None, 28, 28, 16)            0           
## ___________________________________________________________________________
## activation_3 (Activation)        (None, 28, 28, 16)            0           
## ___________________________________________________________________________
## flatten_1 (Flatten)              (None, 12544)                 0           
## ___________________________________________________________________________
## dense_1 (Dense)                  (None, 20)                    250900      
## ___________________________________________________________________________
## activation_4 (Activation)        (None, 20)                    0           
## ===========================================================================
## Total params: 255,988
## Trainable params: 255,988
## Non-trainable params: 0
## ___________________________________________________________________________
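The shapes and parameter counts in this summary follow from simple arithmetic. The sketch below recomputes them (the helper functions are illustrative, not part of keras): "same" padding keeps the spatial size, each 2x2 pooling halves it (224, 112, 56, 28), a 3x3 convolution has 3 x 3 x (input channels) x (filters) weights plus one bias per filter, and a dense layer has one weight per input-output pair plus one bias per output. It also shows that about 98% of the parameters sit in the final dense layer.

```r
# Parameter counts for the layers in the summary above
conv_params  <- function(k, c_in, filters) k * k * c_in * filters + filters  # weights + biases
dense_params <- function(n_in, units) n_in * units + units

# "same" padding keeps 224x224; each 2x2 max pooling halves height and width
spatial <- 224 / 2^(1:3)         # 112 56 28
flat    <- spatial[3]^2 * 16     # 28 * 28 * 16 = 12544 inputs to the dense layer

params <- c(conv2d_1 = conv_params(3, 3, 16),
            conv2d_2 = conv_params(3, 16, 16),
            conv2d_3 = conv_params(3, 16, 16),
            dense_1  = dense_params(flat, 20))
params                              # 448 2320 2320 250900
sum(params)                         # 255988
params[["dense_1"]] / sum(params)   # about 0.98
```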

Train the simple model for 10 epochs, holding out 10% of the training images as an internal validation set:

history <- simple_model %>% fit(Z_train, y_train, epochs = 10,
      validation_split = 0.1)
plot(history)

Let’s look at how the simple model performs on the training and validation sets:

# Predicted class = column with the highest probability (0-based, like y)
simple_y_pred <- apply(predict(simple_model, Z), 1, which.max) - 1L
tapply(y == simple_y_pred, z_train_id, mean)
##      train      valid 
## 0.06636501 0.03800475

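This is terrible! The model barely learns anything: 6.6% training accuracy and 3.8% validation accuracy are around what random guessing among 20 classes (5%) would achieve. Let’s see if we can do better with transfer learning.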
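To put the simple model's accuracy in context, it helps to compare it with two trivial baselines: uniform random guessing (1/20 = 5%) and always predicting the most frequent class. Because the classes are imbalanced, the second is a bit higher; judging from the confusion matrix later on, the largest class (helicopter, 88 of 1,084 images) accounts for about 8% of the data. A small sketch with hypothetical helper names and toy labels; on the notebook's data, majority_baseline(y) would give that share directly.

```r
# Hypothetical helpers for trivial accuracy baselines
uniform_baseline  <- function(n_classes) 1 / n_classes
majority_baseline <- function(y) max(table(y)) / length(y)

uniform_baseline(20)                     # 0.05
majority_baseline(c(0, 0, 0, 1, 2, 2))   # 0.5
```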

For the transfer learning task, we will use ResNet50, a network pre-trained on over a million images from the ImageNet database. It is 50 layers deep and can classify images into 1000 object categories.

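We will use ResNet50 without its final classification layer to embed each image into a much more compact representation: a 2048-dimensional feature vector taken from the network's global average pooling layer (avg_pool). Because these features were learned on ImageNet, they already capture general visual patterns, which should make our classification task significantly easier.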
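For a 224x224 input, ResNet50's last convolutional block produces a 7x7x2048 feature map, and the average pooling layer averages each of the 2048 channels over the 7x7 spatial positions; that is where the 2048-dimensional vector comes from. The sketch below reproduces that averaging on a random array standing in for a real feature map (no actual ResNet50 activations are involved):

```r
set.seed(1)
# Random stand-in for one image's final feature map: height x width x channels
feature_map <- array(runif(7 * 7 * 2048), dim = c(7, 7, 2048))

# Global average pooling: one mean per channel
embedding <- apply(feature_map, 3, mean)
length(embedding)   # 2048

# The same result via colMeans, averaging over the first two dimensions
all.equal(embedding, colMeans(feature_map, dims = 2))

# For comparison, the raw input has 224 * 224 * 3 = 150528 values per image
224 * 224 * 3
```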

Load the pre-trained model and take the output of its second-to-last layer, avg_pool:

resnet50 <- application_resnet50(weights = 'imagenet', include_top = TRUE)
model_avg_pool <- keras_model(inputs = resnet50$input,
                              outputs = get_layer(resnet50, 'avg_pool')$output)

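Embed every image with the truncated network. imagenet_preprocess_input() first applies the input preprocessing that the pre-trained ResNet50 weights expect: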

X_embedded <- predict(model_avg_pool, x = imagenet_preprocess_input(Z), verbose = TRUE)
dim(X_embedded)
## [1] 1084    1    1 2048
X <- drop(X_embedded)   # remove the two singleton dimensions
dim(X)
## [1] 1084 2048
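The preprocessing step matters because the network expects its inputs prepared the same way as its training data. The default mode of imagenet_preprocess_input() is "caffe", which is what Keras's ResNet50 weights use: channels are reordered from RGB to BGR and the per-channel ImageNet means are subtracted, with no rescaling. A base R sketch of that transformation (caffe_preprocess() is a hypothetical helper; the mean values are the ones Keras uses, in BGR order):

```r
# Hypothetical re-implementation of "caffe"-style preprocessing for an
# n x height x width x 3 array of RGB pixel values in [0, 255]
caffe_preprocess <- function(x) {
  x <- x[, , , 3:1, drop = FALSE]           # RGB -> BGR
  bgr_means <- c(103.939, 116.779, 123.68)  # ImageNet channel means (B, G, R)
  for (ch in 1:3) x[, , , ch] <- x[, , , ch] - bgr_means[ch]
  x
}

# One 2x2 "image" whose pixels are pure red (R = 255, G = 0, B = 0)
img <- array(0, dim = c(1, 2, 2, 3))
img[, , , 1] <- 255
out <- caffe_preprocess(img)
out[1, 1, 1, ]   # -103.939 -116.779 131.32
```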

Now that each image is represented by a 2048-dimensional vector, we can train an ordinary classifier that takes these vectors as input.

We split X into training and validation sets, roughly 60/40. Note that this is a new random split, not the one used for the simple model:

train_id <- sample(c("train", "valid"), nrow(X), TRUE, prob = c(0.6, 0.4))

X_train <- X[train_id == "train",]                  # Note: X is a matrix
y_train <- to_categorical(y[train_id == "train"])
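sample() with prob = c(0.6, 0.4) assigns each image independently, so the split is only approximately 60/40 and class proportions can drift between the two sets, which matters when class sizes range from a few dozen to nearly 90 images. A stratified alternative draws 60% within each class. This is a sketch with a hypothetical helper, stratified_split(), and toy labels, not the split used above:

```r
# Hypothetical helper: label each example "train" or "valid", taking
# round(p * n_c) training examples from each class c
stratified_split <- function(y, p = 0.6) {
  split_id <- rep("valid", length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)
    n_train <- round(p * length(idx))
    chosen <- idx[sample.int(length(idx), n_train)]  # safe even if length(idx) == 1
    split_id[chosen] <- "train"
  }
  split_id
}

set.seed(1)
y_toy <- rep(0:2, times = c(10, 20, 50))
split_toy <- stratified_split(y_toy)
table(y_toy, split_toy)   # 6/4, 12/8 and 30/20 train/valid per class
```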

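Then we train a small fully connected network on the embedded images: two hidden layers of 256 ReLU units, each followed by 50% dropout, and a softmax output layer: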

model <- keras_model_sequential()
model %>%
  layer_dense(units = 256, input_shape = ncol(X_train)) %>%
  layer_activation(activation = "relu") %>%
  layer_dropout(rate = 0.5) %>%

  layer_dense(units = 256) %>%
  layer_activation(activation = "relu") %>%
  layer_dropout(rate = 0.5) %>%

  layer_dense(units = ncol(y_train)) %>%
  layer_activation(activation = "softmax")
model %>% compile(loss = 'categorical_crossentropy',
                  optimizer = optimizer_rmsprop(lr = 0.0005),
                  metrics = c('accuracy'))
history <- model %>%
  fit(X_train, y_train, epochs = 10)
plot(history)
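The layer_dropout(rate = 0.5) layers are the network's main defence against overfitting. During training, Keras's dropout sets each activation to zero with probability rate and scales the surviving ones by 1 / (1 - rate), so the expected activation is unchanged; at prediction time it passes inputs through untouched. A base R sketch of that behaviour (the dropout() helper is hypothetical):

```r
# Hypothetical sketch of (inverted) dropout as applied during training
dropout <- function(x, rate = 0.5, training = TRUE) {
  if (!training || rate == 0) return(x)   # identity at prediction time
  keep <- runif(length(x)) >= rate        # keep each unit with prob 1 - rate
  x * keep / (1 - rate)                   # rescale the survivors
}

set.seed(1)
h <- rep(2, 10000)                  # a batch of identical activations
h_drop <- dropout(h, rate = 0.5)
table(h_drop)                       # roughly half 0s, half 4s
mean(h_drop)                        # close to 2, the original mean
dropout(h, training = FALSE)[1:3]   # 2 2 2
```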

Here is the accuracy of the model on both train and validation sets:

# Predicted class = column with the highest probability (0-based, like y)
y_pred <- apply(predict(model, X), 1, which.max) - 1L
tapply(y == y_pred, train_id, mean)
##     train     valid 
## 1.0000000 0.9689737

This works much better than the simple model: validation accuracy rises from about 4% to about 97% (406 of 419 validation images classified correctly). The confusion matrix shows which classes get mixed up:

table(value = class_names[y + 1L], prediction = class_names[y_pred + 1L], train_id)
## , , train_id = train
## 
##                prediction
## value           crab cup helicopter lobster lotus mandolin mayfly pigeon
##   crab            38   0          0       0     0        0      0      0
##   cup              0  36          0       0     0        0      0      0
##   helicopter       0   0         51       0     0        0      0      0
##   lobster          0   0          0      29     0        0      0      0
##   lotus            0   0          0       0    42        0      0      0
##   mandolin         0   0          0       0     0       23      0      0
##   mayfly           0   0          0       0     0        0     29      0
##   pigeon           0   0          0       0     0        0      0     29
##   pizza            0   0          0       0     0        0      0      0
##   platypus         0   0          0       0     0        0      0      0
##   pyramid          0   0          0       0     0        0      0      0
##   revolver         0   0          0       0     0        0      0      0
##   rhino            0   0          0       0     0        0      0      0
##   rooster          0   0          0       0     0        0      0      0
##   saxophone        0   0          0       0     0        0      0      0
##   schooner         0   0          0       0     0        0      0      0
##   scissors         0   0          0       0     0        0      0      0
##   windsor_chair    0   0          0       0     0        0      0      0
##   wrench           0   0          0       0     0        0      0      0
##   yin_yang         0   0          0       0     0        0      0      0
##                prediction
## value           pizza platypus pyramid revolver rhino rooster saxophone
##   crab              0        0       0        0     0       0         0
##   cup               0        0       0        0     0       0         0
##   helicopter        0        0       0        0     0       0         0
##   lobster           0        0       0        0     0       0         0
##   lotus             0        0       0        0     0       0         0
##   mandolin          0        0       0        0     0       0         0
##   mayfly            0        0       0        0     0       0         0
##   pigeon            0        0       0        0     0       0         0
##   pizza            33        0       0        0     0       0         0
##   platypus          0       22       0        0     0       0         0
##   pyramid           0        0      38        0     0       0         0
##   revolver          0        0       0       44     0       0         0
##   rhino             0        0       0        0    33       0         0
##   rooster           0        0       0        0     0      27         0
##   saxophone         0        0       0        0     0       0        24
##   schooner          0        0       0        0     0       0         0
##   scissors          0        0       0        0     0       0         0
##   windsor_chair     0        0       0        0     0       0         0
##   wrench            0        0       0        0     0       0         0
##   yin_yang          0        0       0        0     0       0         0
##                prediction
## value           schooner scissors windsor_chair wrench yin_yang
##   crab                 0        0             0      0        0
##   cup                  0        0             0      0        0
##   helicopter           0        0             0      0        0
##   lobster              0        0             0      0        0
##   lotus                0        0             0      0        0
##   mandolin             0        0             0      0        0
##   mayfly               0        0             0      0        0
##   pigeon               0        0             0      0        0
##   pizza                0        0             0      0        0
##   platypus             0        0             0      0        0
##   pyramid              0        0             0      0        0
##   revolver             0        0             0      0        0
##   rhino                0        0             0      0        0
##   rooster              0        0             0      0        0
##   saxophone            0        0             0      0        0
##   schooner            38        0             0      0        0
##   scissors             0       27             0      0        0
##   windsor_chair        0        0            38      0        0
##   wrench               0        0             0     27        0
##   yin_yang             0        0             0      0       37
## 
## , , train_id = valid
## 
##                prediction
## value           crab cup helicopter lobster lotus mandolin mayfly pigeon
##   crab            34   0          0       1     0        0      0      0
##   cup              0  18          0       0     0        0      0      0
##   helicopter       0   0         36       0     0        0      0      0
##   lobster          0   0          1      11     0        0      0      0
##   lotus            0   0          0       0    24        0      0      0
##   mandolin         0   0          0       0     0       19      0      0
##   mayfly           0   0          0       0     0        0     11      0
##   pigeon           0   0          0       0     0        0      0     16
##   pizza            0   0          0       0     0        0      0      0
##   platypus         1   0          0       0     0        0      0      0
##   pyramid          0   0          0       0     0        0      0      0
##   revolver         0   0          1       0     0        0      0      0
##   rhino            0   0          0       0     0        0      0      0
##   rooster          0   0          0       0     0        0      0      1
##   saxophone        0   0          0       0     0        0      0      0
##   schooner         0   0          0       0     0        0      0      0
##   scissors         0   0          0       0     0        0      0      0
##   windsor_chair    0   0          0       0     0        0      0      0
##   wrench           0   0          0       1     0        0      0      0
##   yin_yang         0   0          0       0     0        0      0      0
##                prediction
## value           pizza platypus pyramid revolver rhino rooster saxophone
##   crab              0        0       0        0     0       0         0
##   cup               0        0       0        0     0       1         0
##   helicopter        0        0       0        1     0       0         0
##   lobster           0        0       0        0     0       0         0
##   lotus             0        0       0        0     0       0         0
##   mandolin          0        0       0        0     0       0         0
##   mayfly            0        0       0        0     0       0         0
##   pigeon            0        0       0        0     0       0         0
##   pizza            20        0       0        0     0       0         0
##   platypus          0       11       0        0     0       0         0
##   pyramid           0        0      19        0     0       0         0
##   revolver          0        0       0       36     0       0         0
##   rhino             0        0       0        0    26       0         0
##   rooster           0        0       0        0     0      21         0
##   saxophone         0        0       0        0     0       0        16
##   schooner          0        0       0        0     0       0         0
##   scissors          0        0       0        0     0       0         0
##   windsor_chair     0        0       0        0     0       0         0
##   wrench            0        0       0        0     0       0         0
##   yin_yang          0        0       0        0     0       0         0
##                prediction
## value           schooner scissors windsor_chair wrench yin_yang
##   crab                 0        0             0      0        0
##   cup                  0        2             0      0        0
##   helicopter           0        0             0      0        0
##   lobster              0        0             0      0        0
##   lotus                0        0             0      0        0
##   mandolin             0        0             0      1        0
##   mayfly               0        0             0      0        0
##   pigeon               0        0             0      0        0
##   pizza                0        0             0      0        0
##   platypus             0        0             0      0        0
##   pyramid              0        0             0      0        0
##   revolver             0        0             0      1        0
##   rhino                0        0             0      0        0
##   rooster              0        0             0      0        0
##   saxophone            0        0             0      0        0
##   schooner            25        0             0      0        0
##   scissors             0       12             0      0        0
##   windsor_chair        0        0            18      0        0
##   wrench               0        1             0     10        0
##   yin_yang             0        0             0      0       23
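A confusion matrix this large is easier to read if you pull out only the off-diagonal cells. The sketch below does that, together with overall accuracy (diagonal sum over total) and per-class recall, for any pair of true and predicted label vectors. summarise_confusion() is a hypothetical helper and the labels are toy data, not the notebook's results; on the real data, summarise_confusion(y[train_id == "valid"], y_pred[train_id == "valid"]) would list the validation errors.

```r
# Hypothetical helper summarising a confusion matrix
summarise_confusion <- function(truth, pred) {
  lv <- sort(union(truth, pred))
  cm <- table(value = factor(truth, levels = lv),
              prediction = factor(pred, levels = lv))
  errors <- as.data.frame(cm, stringsAsFactors = FALSE)
  errors <- errors[errors$value != errors$prediction & errors$Freq > 0, ]
  list(accuracy = sum(diag(cm)) / sum(cm),
       recall   = diag(cm) / rowSums(cm),        # per true class
       errors   = errors[order(-errors$Freq), ]) # most frequent confusions first
}

truth <- c("cup", "cup", "cup", "wrench", "wrench", "scissors")
pred  <- c("cup", "scissors", "scissors", "wrench", "scissors", "scissors")
res <- summarise_confusion(truth, pred)
res$accuracy   # 0.5
res$errors     # cup -> scissors (2), wrench -> scissors (1)
```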

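Here are all the misclassified images, with the true label in green and the predicted label in red. Since the model fits the training set perfectly, all of them come from the validation set: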

par(mfrow = c(4, 4))
id <- which(y_pred != y)
for (i in id) {
  par(mar = rep(0, 4L))
  plot(0,0,xlim=c(0,1),ylim=c(0,1),axes= FALSE,type = "n")
  rasterImage(Z[i,,,] /255,0,0,1,1)
  text(0.5, 0.1, label = class_names[y[i] + 1L], col = "green", cex=2)
  text(0.5, 0.3, label = class_names[y_pred[i] + 1L], col = "red", cex=2)
}

Here is the most representative image for each class according to our model, meaning the image that receives the highest predicted probability for that class:

y_probs <- predict(model, X)
par(mfrow = c(4, 5))
for (i in 0:19) {
  par(mar = rep(0, 4L))
  plot(0,0,xlim=c(0,1),ylim=c(0,1),axes= FALSE,type = "n")
  j <- order(y_probs[,i+1], decreasing = TRUE)[1]
  rasterImage(Z[j,,,]/255,0,0,1,1)
  text(0.5, 0.1, class_names[i+1], cex = 2, col = "salmon")
}
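Picking these images is an argmax down each column of the probability matrix: for each class, the row with the highest predicted probability. apply(y_probs, 2, which.max) does this for all classes at once, as in this toy sketch with a made-up 4x3 probability matrix; reading the same matrix row-wise gives each image's predicted class. Note that the selected image is the one the model is most confident about, which is not necessarily the most typical example of the class.

```r
# Made-up predicted probabilities: rows = images, columns = classes
y_probs_toy <- rbind(c(0.70, 0.20, 0.10),
                     c(0.10, 0.85, 0.05),
                     c(0.90, 0.05, 0.05),
                     c(0.20, 0.30, 0.50))

# Row index of the most confident image for each class (column-wise argmax)
best_image <- apply(y_probs_toy, 2, which.max)
best_image   # 3 2 4

# Row-wise argmax gives each image's predicted class (0-based)
apply(y_probs_toy, 1, which.max) - 1L   # 0 1 0 2
```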

The results show that the transfer learning model is very effective. Even the misclassified images are ones that a human might plausibly get wrong as well.

In conclusion, transfer learning allowed us to build an accurate classifier for our own 20 classes from only about a thousand images.